import torch
import torchaudio

def get_LFCC_80(x: torch.Tensor) -> torch.Tensor:
    """Extract 80 LFCCs plus their first- and second-order deltas.

    The input is flattened to a single channel, ``(1, num_samples)``, so any
    input shape is accepted. ``torchaudio.transforms.LFCC`` is built with its
    defaults except ``n_lfcc=80``; none of its other parameters (sample
    rate, FFT size, ...) are configurable here.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Tensor of shape ``(1, num_frames, 240)``. Columns ``0:80`` hold the
        LFCCs, ``80:160`` their deltas and ``160:240`` the delta-deltas.
    """
    mono = x.reshape(1, -1)
    # The transform's window/filterbank/DCT buffers are created on the
    # default (CPU) device; move them next to the input so tensors on an
    # accelerator do not fail with a device mismatch.
    transform = torchaudio.transforms.LFCC(n_lfcc=80).to(mono.device)
    coeffs = transform(mono)  # (1, 80, num_frames)
    delta = torchaudio.functional.compute_deltas(coeffs)
    delta2 = torchaudio.functional.compute_deltas(delta)
    # Stack along the coefficient axis, then put the time axis first.
    return torch.cat([coeffs, delta, delta2], dim=1).transpose(1, 2)


def get_LFCC_70(x: torch.Tensor) -> torch.Tensor:
    """Extract 70 LFCCs plus their first- and second-order deltas.

    The input is flattened to a single channel, ``(1, num_samples)``, so any
    input shape is accepted. ``torchaudio.transforms.LFCC`` is built with its
    defaults except ``n_lfcc=70``; none of its other parameters are
    configurable here.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Tensor of shape ``(1, num_frames, 210)``. Columns ``0:70`` hold the
        LFCCs, ``70:140`` their deltas and ``140:210`` the delta-deltas.
    """
    mono = x.reshape(1, -1)
    # The transform's buffers live on the default (CPU) device; move them to
    # the input's device so non-CPU inputs work.
    transform = torchaudio.transforms.LFCC(n_lfcc=70).to(mono.device)
    coeffs = transform(mono)  # (1, 70, num_frames)
    delta = torchaudio.functional.compute_deltas(coeffs)
    delta2 = torchaudio.functional.compute_deltas(delta)
    # Stack along the coefficient axis, then put the time axis first.
    return torch.cat([coeffs, delta, delta2], dim=1).transpose(1, 2)

def get_LFCC_60(x: torch.Tensor) -> torch.Tensor:
    """Extract 60 LFCCs plus their first- and second-order deltas.

    The input is flattened to a single channel, ``(1, num_samples)``, so any
    input shape is accepted. ``torchaudio.transforms.LFCC`` is built with its
    defaults except ``n_lfcc=60``; none of its other parameters are
    configurable here.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Tensor of shape ``(1, num_frames, 180)``. Columns ``0:60`` hold the
        LFCCs, ``60:120`` their deltas and ``120:180`` the delta-deltas.
    """
    mono = x.reshape(1, -1)
    # The transform's buffers live on the default (CPU) device; move them to
    # the input's device so non-CPU inputs work.
    transform = torchaudio.transforms.LFCC(n_lfcc=60).to(mono.device)
    coeffs = transform(mono)  # (1, 60, num_frames)
    delta = torchaudio.functional.compute_deltas(coeffs)
    delta2 = torchaudio.functional.compute_deltas(delta)
    # Stack along the coefficient axis, then put the time axis first.
    return torch.cat([coeffs, delta, delta2], dim=1).transpose(1, 2)

def get_MFCC_30(x: torch.Tensor) -> torch.Tensor:
    """Extract 30 MFCCs plus their first- and second-order deltas.

    The input is flattened to a single channel, ``(1, num_samples)``, so any
    input shape is accepted. ``torchaudio.transforms.MFCC`` is built with its
    defaults except ``n_mfcc=30``; none of its other parameters (sample
    rate, mel settings, ...) are configurable here.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Tensor of shape ``(1, num_frames, 90)``. Columns ``0:30`` hold the
        MFCCs, ``30:60`` their deltas and ``60:90`` the delta-deltas.
    """
    mono = x.reshape(1, -1)
    # The transform's window/mel-filterbank/DCT buffers are created on the
    # default (CPU) device; move them next to the input so tensors on an
    # accelerator do not fail with a device mismatch.
    transform = torchaudio.transforms.MFCC(n_mfcc=30).to(mono.device)
    coeffs = transform(mono)  # (1, 30, num_frames)
    delta = torchaudio.functional.compute_deltas(coeffs)
    delta2 = torchaudio.functional.compute_deltas(delta)
    # Stack along the coefficient axis, then put the time axis first.
    return torch.cat([coeffs, delta, delta2], dim=1).transpose(1, 2)

def get_MFCC_40(x: torch.Tensor) -> torch.Tensor:
    """Extract 40 MFCCs plus their first- and second-order deltas.

    The input is flattened to a single channel, ``(1, num_samples)``, so any
    input shape is accepted. ``torchaudio.transforms.MFCC`` is built with its
    defaults except ``n_mfcc=40``; none of its other parameters are
    configurable here.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Tensor of shape ``(1, num_frames, 120)``. Columns ``0:40`` hold the
        MFCCs, ``40:80`` their deltas and ``80:120`` the delta-deltas.
    """
    mono = x.reshape(1, -1)
    # The transform's buffers live on the default (CPU) device; move them to
    # the input's device so non-CPU inputs work.
    transform = torchaudio.transforms.MFCC(n_mfcc=40).to(mono.device)
    coeffs = transform(mono)  # (1, 40, num_frames)
    delta = torchaudio.functional.compute_deltas(coeffs)
    delta2 = torchaudio.functional.compute_deltas(delta)
    # Stack along the coefficient axis, then put the time axis first.
    return torch.cat([coeffs, delta, delta2], dim=1).transpose(1, 2)

def get_MFCC_80(x: torch.Tensor) -> torch.Tensor:
    """Extract 80 MFCCs plus their first- and second-order deltas.

    The input is flattened to a single channel, ``(1, num_samples)``, so any
    input shape is accepted. ``torchaudio.transforms.MFCC`` is built with its
    defaults except ``n_mfcc=80``; none of its other parameters are
    configurable here.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Tensor of shape ``(1, num_frames, 240)``. Columns ``0:80`` hold the
        MFCCs, ``80:160`` their deltas and ``160:240`` the delta-deltas.
    """
    mono = x.reshape(1, -1)
    # The transform's buffers live on the default (CPU) device; move them to
    # the input's device so non-CPU inputs work.
    transform = torchaudio.transforms.MFCC(n_mfcc=80).to(mono.device)
    coeffs = transform(mono)  # (1, 80, num_frames)
    delta = torchaudio.functional.compute_deltas(coeffs)
    delta2 = torchaudio.functional.compute_deltas(delta)
    # Stack along the coefficient axis, then put the time axis first.
    return torch.cat([coeffs, delta, delta2], dim=1).transpose(1, 2)

def get_SPEC_3072(x: torch.Tensor) -> torch.Tensor:
    """Return a dB-scaled power spectrogram computed with a 3072-point FFT.

    The input is flattened to ``(1, num_samples)``. The STFT uses a
    1024-sample Hann window (``win_length=1024``, so ``torch.stft`` pads it
    to ``n_fft=3072``), a hop of 256 samples and ``torch.stft``'s default
    ``center``/``pad_mode``/``onesided`` settings.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Real tensor of shape ``(1, 1537, num_frames)`` holding
        ``10 * log10(max(|STFT|**2, 1e-10))``. Values more than 80 dB below
        the global maximum of that tensor are raised to ``max - 80``.
    """
    def torch_power_to_db(S, amin=1e-10, top_db=80.0):
        # Clamp with a Python scalar rather than a freshly built CPU tensor,
        # so this works for S on any device and keeps S's dtype.
        log_spec = 10.0 * torch.log10(torch.clamp(S, min=amin))
        # Floor everything at top_db below the global peak.
        return torch.maximum(log_spec, log_spec.max() - top_db)

    mono = x.reshape(1, -1)
    # Build the window on the input's device; torch.stft rejects a window
    # that lives on a different device than the signal.
    window = torch.hann_window(1024, device=mono.device)
    stft = torch.stft(mono, n_fft=3072, win_length=1024, hop_length=256,
                      window=window, return_complex=True)
    power = stft.abs() ** 2
    return torch_power_to_db(power)

def get_SPEC_2048(x: torch.Tensor) -> torch.Tensor:
    """Return a dB-scaled power spectrogram computed with a 2048-point FFT.

    The input is flattened to ``(1, num_samples)``. The STFT uses a
    1024-sample Hann window (``win_length=1024``, so ``torch.stft`` pads it
    to ``n_fft=2048``), a hop of 256 samples and ``torch.stft``'s default
    ``center``/``pad_mode``/``onesided`` settings.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Real tensor of shape ``(1, 1025, num_frames)`` holding
        ``10 * log10(max(|STFT|**2, 1e-10))``. Values more than 80 dB below
        the global maximum of that tensor are raised to ``max - 80``.
    """
    def torch_power_to_db(S, amin=1e-10, top_db=80.0):
        # Clamp with a Python scalar rather than a freshly built CPU tensor,
        # so this works for S on any device and keeps S's dtype.
        log_spec = 10.0 * torch.log10(torch.clamp(S, min=amin))
        # Floor everything at top_db below the global peak.
        return torch.maximum(log_spec, log_spec.max() - top_db)

    mono = x.reshape(1, -1)
    # Build the window on the input's device; torch.stft rejects a window
    # that lives on a different device than the signal.
    window = torch.hann_window(1024, device=mono.device)
    stft = torch.stft(mono, n_fft=2048, win_length=1024, hop_length=256,
                      window=window, return_complex=True)
    power = stft.abs() ** 2
    return torch_power_to_db(power)

def get_SPEC_1024(x: torch.Tensor) -> torch.Tensor:
    """Return a dB-scaled power spectrogram computed with a 1024-point FFT.

    The input is flattened to ``(1, num_samples)``. The STFT uses a
    1024-sample Hann window (equal to ``n_fft``, so no window padding), a
    hop of 256 samples and ``torch.stft``'s default
    ``center``/``pad_mode``/``onesided`` settings.

    Args:
        x: Waveform samples, presumably a real floating-point tensor.

    Returns:
        Real tensor of shape ``(1, 513, num_frames)`` holding
        ``10 * log10(max(|STFT|**2, 1e-10))``. Values more than 80 dB below
        the global maximum of that tensor are raised to ``max - 80``.
    """
    def torch_power_to_db(S, amin=1e-10, top_db=80.0):
        # Clamp with a Python scalar rather than a freshly built CPU tensor,
        # so this works for S on any device and keeps S's dtype.
        log_spec = 10.0 * torch.log10(torch.clamp(S, min=amin))
        # Floor everything at top_db below the global peak.
        return torch.maximum(log_spec, log_spec.max() - top_db)

    mono = x.reshape(1, -1)
    # Build the window on the input's device; torch.stft rejects a window
    # that lives on a different device than the signal.
    window = torch.hann_window(1024, device=mono.device)
    stft = torch.stft(mono, n_fft=1024, win_length=1024, hop_length=256,
                      window=window, return_complex=True)
    power = stft.abs() ** 2
    return torch_power_to_db(power)